import flowvision
import os


class SUN397(flowvision.datasets.ImageFolder):
    """SUN397 scene-recognition dataset read from an on-disk ImageFolder tree.

    ``root`` is expected to contain a ``train`` and a ``test`` subdirectory.
    Each is laid out the way ``flowvision.datasets.ImageFolder`` expects
    (one subfolder per class).

    Args:
        root: Directory holding the ``train`` and ``test`` split folders.
        split: ``'train'`` selects ``root/train``. Any other value selects
            ``root/test``. For example, ``'val'`` or ``'test'`` both resolve
            to the test folder; this mapping is kept for backward
            compatibility.
        **kwargs: Forwarded unchanged to ``ImageFolder.__init__``, e.g.
            ``transform``, ``target_transform``, ``loader`` or
            ``is_valid_file``.
    """

    def __init__(self, root, split='train', **kwargs):
        self.split = split
        # Fixed class count for SUN397. NOTE(review): this is not
        # cross-checked against the class folders discovered on disk.
        self.num_classes = 397
        # Only the exact string 'train' picks the training folder.
        subdir = 'train' if split == 'train' else 'test'
        # Forward kwargs so callers can supply transforms/loaders; without
        # them ImageFolder returns untransformed samples.
        super().__init__(os.path.join(root, subdir), **kwargs)
        print('SUN397, Split: %s, Size: %d' % (self.split, len(self.imgs)))
